import os
import json
import pickle
import shutil

# Dataset whose experiments this module manages; every folder below is
# namespaced by it.
dataset_name = 'reddit'

# Root of the whole experiment tree. Placeholder value: edit before running.
base_folder = "enter-base-folder-here"

# Per-dataset roots, one per pipeline stage.
model_bank_folder = os.path.join(base_folder, "intermediate_results", dataset_name, "model_bank")
interpolation_folder = os.path.join(base_folder, "intermediate_results", dataset_name, "interpolated_models")
routines_folder = os.path.join(base_folder, "experiment_routines", dataset_name)
final_results_folder = os.path.join(base_folder, "final_results", dataset_name)

# File names of the JSON bookkeeping files kept inside those folders.
samples_dict_file_name, interpolation_dict_file_name, routines_file_name = (
    "samples_dict.json",
    "interpolation_dict.json",
    "experiment_routines.json",
)

final_result_types = ["cross_entropy", "generations"]

def initialize_corpus(corpus_name):
    """Ensure the per-corpus output folders exist and return their paths.

    A sub-folder named ``corpus_name`` is created (if missing) under each of
    ``model_bank_folder``, ``interpolation_folder`` and ``final_results_folder``.

    Args:
        corpus_name: Corpus name, used as the sub-folder name.

    Returns:
        Tuple ``(bank_output_folder, interpolation_output_folder,
        final_output_folder)``.
    """
    bank_output_folder = os.path.join(model_bank_folder, corpus_name)
    interpolation_output_folder = os.path.join(interpolation_folder, corpus_name)
    final_output_folder = os.path.join(final_results_folder, corpus_name)

    # makedirs(exist_ok=True) also creates missing parents (e.g. on a fresh
    # base_folder) and has no check-then-create race.
    for folder in (bank_output_folder, interpolation_output_folder, final_output_folder):
        os.makedirs(folder, exist_ok=True)

    return bank_output_folder, interpolation_output_folder, final_output_folder


def initialize_base_model(bank_output_folder, model_name):
    """Ensure a model's bank folder and its samples-dict JSON file exist.

    Creates ``<bank_output_folder>/<model_name>/`` if needed. If that folder
    has no samples-dict file yet, an empty JSON object is written to it; an
    existing samples-dict file is left untouched.

    Args:
        bank_output_folder: Per-corpus model-bank folder.
        model_name: Name of the model; used as the sub-folder name.
    """
    model_folder = os.path.join(bank_output_folder, model_name)
    os.makedirs(model_folder, exist_ok=True)

    model_samples_dict_file = os.path.join(model_folder, samples_dict_file_name)
    # Only a missing file needs initialising; existing bookkeeping is preserved.
    if not os.path.exists(model_samples_dict_file):
        with open(model_samples_dict_file, 'w') as file:
            json.dump({}, file)

def initialize_base_models(bank_output_folder, model_names):
    """Call ``initialize_base_model`` for each entry of ``model_names``, in order."""
    for name in model_names:
        initialize_base_model(bank_output_folder, name)

def check_if_identical(list1, list2, ordered=True):
    """Return True if two sequences contain the same elements.

    Args:
        list1: First sequence.
        list2: Second sequence. Elements only need to support ``==``.
        ordered: If True (default), elements must match position by position.
            If False, the sequences must be equal as multisets: each element
            occurs the same number of times in both, in any order.

    Returns:
        bool: whether the sequences are identical under the chosen mode.
    """
    if len(list1) != len(list2):
        return False

    if ordered:
        return all(a == b for a, b in zip(list1, list2))

    # Consume one matching element per item so duplicates are counted:
    # [1, 1, 2] and [1, 2, 2] are not permutations of each other.
    # list.remove needs only equality, so unhashable elements also work.
    remaining = list(list2)
    for item in list1:
        try:
            remaining.remove(item)
        except ValueError:
            return False
    return True

def collect_data(data_idx, data_list, n_samples_list):
    """Gather field ``data_idx`` of several records into per-position rows.

    Row ``i`` (for ``i`` in ``range(len(n_samples_list))``) is
    ``[record[data_idx][i] for record in data_list]``. Only the length of
    ``n_samples_list`` is used, not its contents.

    Returns:
        List of ``len(n_samples_list)`` rows, each of length ``len(data_list)``.
    """
    n_rows = len(n_samples_list)
    return [[record[data_idx][row] for record in data_list] for row in range(n_rows)]


### ROUTINES RELATED

def load_routine(corpus_name, routine_id):
    """Load one experiment routine for ``corpus_name``.

    Reads the routines JSON file (``routines_file_name``) from
    ``<routines_folder>/<corpus_name>/`` and returns the ``model_names`` and
    ``samples_list`` fields of entry ``routine_id``. A missing id or field
    raises ``KeyError``.

    Returns:
        Tuple ``(model_names, samples_list)``.
    """
    routine_path = os.path.join(routines_folder, corpus_name, routines_file_name)
    with open(routine_path, 'r') as handle:
        all_routines = json.load(handle)
    routine = all_routines[routine_id]
    return routine['model_names'], routine['samples_list']

### FINETUNING RELATED

def check_if_models_exist(bank_output_folder, model_names, corpus_ids):
    """Look up, for each model, a stored fine-tuning run on exactly ``corpus_ids``.

    For every model name, reads its samples-dict JSON and picks the first entry
    whose samples list equals ``corpus_ids`` position by position.

    Returns:
        List parallel to ``model_names``: the matching samples folder path
        (``<bank_output_folder>/<model_name>/<key>``), or None where no entry
        matches.
    """
    found = []
    for model_name in model_names:
        model_folder = os.path.join(bank_output_folder, model_name)
        with open(os.path.join(model_folder, samples_dict_file_name), 'r') as handle:
            samples_dict = json.load(handle)

        match = None
        for key, samples_list in samples_dict.items():
            if check_if_identical(samples_list, corpus_ids):
                match = os.path.join(model_folder, key)
                break
        found.append(match)
    return found

def save_finetuned_model_data(bank_output_folder, model_name, corpus_ids, data):
    """Store fine-tuning output for ``model_name`` in a fresh samples folder.

    Picks the first name ``samples_list_<k>`` (k = 0, 1, ...) that is not yet a
    key of the model's samples-dict JSON. It then pickles ``data`` to
    ``<folder>/finetuning_data.pkl``, records ``corpus_ids`` under that name,
    and rewrites the samples-dict JSON.

    Args:
        bank_output_folder: Per-corpus model-bank folder.
        model_name: Model sub-folder; its samples-dict file must already exist.
        corpus_ids: JSON-serialisable samples list to record for this run.
        data: Picklable payload to store.

    Returns:
        Path of the samples folder the data was written to.
    """
    model_folder = os.path.join(bank_output_folder, model_name)
    model_samples_dict_file = os.path.join(model_folder, samples_dict_file_name)
    with open(model_samples_dict_file, 'r') as file:
        samples_dict = json.load(file)

    folder_id = 0
    while f"samples_list_{folder_id}" in samples_dict:
        folder_id += 1
    samples_folder_name = f"samples_list_{folder_id}"

    curr_samples_folder = os.path.join(model_folder, samples_folder_name)
    os.makedirs(curr_samples_folder, exist_ok=True)

    with open(os.path.join(curr_samples_folder, 'finetuning_data.pkl'), 'wb') as file:
        pickle.dump(data, file)

    # Persist the bookkeeping entry only after the payload is on disk.
    samples_dict[samples_folder_name] = corpus_ids
    with open(model_samples_dict_file, 'w') as file:
        json.dump(samples_dict, file)

    return curr_samples_folder


### INTERPOLATION RELATED ###

def initialize_interpolation(interpolation_output_folder, int_type):
    """Ensure the folder and bookkeeping JSON for one interpolation type exist.

    Creates ``<interpolation_output_folder>/<int_type>/`` if needed. If the
    interpolation-dict file inside it is missing, an empty JSON object is
    written to it; an existing dict file is never overwritten.

    Args:
        interpolation_output_folder: Per-corpus interpolation folder.
        int_type: Interpolation type; used as the sub-folder name.
    """
    curr_interpolation_folder = os.path.join(interpolation_output_folder, int_type)
    curr_interpolation_file = os.path.join(curr_interpolation_folder, interpolation_dict_file_name)

    os.makedirs(curr_interpolation_folder, exist_ok=True)
    # Check the file itself, not just the folder. A folder that exists without
    # its dict file would otherwise make every later read of it fail.
    if not os.path.exists(curr_interpolation_file):
        with open(curr_interpolation_file, 'w') as file:
            json.dump({}, file)

def check_if_interpolation_exists(interpolation_output_folder, model_names, corpus_ids, int_type):
    """Return the folder of a stored interpolation matching the inputs, or None.

    Scans the interpolation-dict JSON of ``int_type`` for the first entry
    that meets two conditions. Its ``samples_list`` must equal ``corpus_ids``
    position by position. Its ``model_names`` must match ``model_names`` under
    ``check_if_identical(..., ordered=False)``.

    Returns:
        ``<interpolation_output_folder>/<int_type>/<key>`` for the first match,
        otherwise None.
    """
    type_folder = os.path.join(interpolation_output_folder, int_type)
    with open(os.path.join(type_folder, interpolation_dict_file_name), 'r') as handle:
        entries = json.load(handle)

    for key, entry in entries.items():
        stored_models, stored_samples = entry["model_names"], entry["samples_list"]
        if not check_if_identical(stored_samples, corpus_ids):
            continue
        if check_if_identical(model_names, stored_models, ordered=False):
            return os.path.join(type_folder, key)
    return None

def save_interpolation_data(interpolation_output_folder, model_names, corpus_ids, int_type, data):
    """Store an interpolation result in a fresh ``interpolation_<k>`` folder.

    Picks the first name ``interpolation_<k>`` (k = 0, 1, ...) not yet a key of
    the ``int_type`` interpolation-dict JSON. It then pickles ``data`` to
    ``<folder>/interpolation_data.pkl``, records ``model_names`` and
    ``corpus_ids`` under that key, and rewrites the dict JSON.

    Args:
        interpolation_output_folder: Per-corpus interpolation folder.
        model_names: JSON-serialisable list of the interpolated models.
        corpus_ids: JSON-serialisable samples list used.
        int_type: Interpolation type sub-folder; its dict file must exist.
        data: Picklable payload to store.

    Returns:
        Path of the folder the data was written to.
    """
    curr_interpolation_folder = os.path.join(interpolation_output_folder, int_type)
    curr_interpolation_file = os.path.join(curr_interpolation_folder, interpolation_dict_file_name)

    with open(curr_interpolation_file, 'r') as file:
        interpolation_dict = json.load(file)

    folder_id = 0
    while f"interpolation_{folder_id}" in interpolation_dict:
        folder_id += 1
    interpolation_folder_name = f"interpolation_{folder_id}"

    interp_folder = os.path.join(curr_interpolation_folder, interpolation_folder_name)
    os.makedirs(interp_folder, exist_ok=True)

    with open(os.path.join(interp_folder, 'interpolation_data.pkl'), 'wb') as file:
        pickle.dump(data, file)

    # Persist the bookkeeping entry only after the payload is on disk.
    interpolation_dict[interpolation_folder_name] = {
        'model_names': model_names,
        'samples_list': corpus_ids,
    }
    with open(curr_interpolation_file, 'w') as file:
        json.dump(interpolation_dict, file)

    return interp_folder


### FINAL RESULTS RELATED ###

def initialize_final_results(final_output_folder, interpolation_output_folder, int_type, interpolation_id):
    """Prepare and return the final-results folder for one interpolation.

    Steps:
    1. Ensure ``<final_output_folder>/<int_type>/`` exists.
    2. Copy the current ``int_type`` interpolation-dict JSON into it,
       overwriting any earlier copy.
    3. Ensure and return ``<final_output_folder>/<int_type>/<interpolation_id>/``.

    NOTE(review): ``interpolation_id`` is presumably a key of that dict, such
    as ``interpolation_0``. Confirm with callers.

    Returns:
        Path of the per-interpolation output folder.
    """
    curr_final_folder = os.path.join(final_output_folder, int_type)
    os.makedirs(curr_final_folder, exist_ok=True)

    # Keep a snapshot of the interpolation bookkeeping next to the results.
    interpolation_file = os.path.join(interpolation_output_folder, int_type, interpolation_dict_file_name)
    shutil.copy(interpolation_file, curr_final_folder)

    output_folder = os.path.join(curr_final_folder, interpolation_id)
    os.makedirs(output_folder, exist_ok=True)
    return output_folder
    
def save_pickle_file(path, data):
    """Pickle ``data`` (default protocol) into the file at ``path``, overwriting it."""
    with open(path, 'wb') as handle:
        pickle.Pickler(handle).dump(data)



